// Copyright 2009 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.enterprise.secmgr.common;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.BoundedExecutorService;

import org.joda.time.DateTimeUtils;

import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.security.SecureRandom;
import java.util.Collection;
import java.util.Formatter;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;

import javax.annotation.CheckReturnValue;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.concurrent.ThreadSafe;

/**
 * Utilities useful throughout the security manager.
 */
@ThreadSafe
public class SecurityManagerUtil {
  private static final Logger LOGGER = Logger.getLogger(SecurityManagerUtil.class.getName());

  // Shared source of randomness for nonces; every use is synchronized on this object.
  private static final SecureRandom prng = new SecureRandom();

  // Maximum tolerated clock difference between this host and a remote host, in the
  // same units as the times passed to the isRemote*TimeValid methods (presumably
  // milliseconds -- confirm with callers).
  private static final long CLOCK_SKEW_TIME = 5000;

  // don't instantiate
  private SecurityManagerUtil() {
    throw new UnsupportedOperationException();
  }

  /**
   * Removes, in place, all elements of an iterable that match a given predicate.
   *
   * <p>The iterable's iterator must support {@link Iterator#remove}.
   *
   * @param iterable The iterable to modify.
   * @param predicate The predicate identifying the elements to remove.
   * @return An immutable collection of the removed elements, in iteration order.
   * @throws NullPointerException if a null element matches the predicate (the
   *     returned collection cannot hold nulls); that element is left in place.
   */
  public static <T> Collection<T> removeInPlace(Iterable<T> iterable, Predicate<T> predicate) {
    ImmutableList.Builder<T> removed = ImmutableList.builder();
    for (Iterator<T> iterator = iterable.iterator(); iterator.hasNext(); ) {
      T element = iterator.next();
      if (predicate.apply(element)) {
        // Record the element before removing it, so that a rejected (null) element
        // is not silently dropped from the source.
        removed.add(element);
        iterator.remove();
      }
    }
    return removed.build();
  }

  /**
   * Annotate a log message with a given session ID.  This should be implemented
   * in the session manager, but can't be due to cyclic build dependencies.
   *
   * @param sessionId The session ID to annotate the message with; may be null,
   *     in which case "?" is used.
   * @param message The log message to annotate.
   * @return The annotated log message.
   */
  public static String sessionLogMessage(String sessionId, String message) {
    return "sid " + ((sessionId != null) ? sessionId : "?") + ": " + message;
  }

  /**
   * Generate a random nonce as a byte array, using a shared {@link SecureRandom}.
   *
   * @param nBytes The number of random bytes to generate.
   * @return A randomly generated byte array of the given length.
   */
  public static byte[] generateRandomNonce(int nBytes) {
    byte[] randomBytes = new byte[nBytes];
    synchronized (prng) {
      prng.nextBytes(randomBytes);
    }
    return randomBytes;
  }

  /**
   * Generate a random nonce as a lower-case hexadecimal string.
   *
   * @param nBytes The number of random bytes to generate.
   * @return A randomly generated hexadecimal string of {@code 2 * nBytes} characters.
   */
  public static String generateRandomNonceHex(int nBytes) {
    return bytesToHex(generateRandomNonce(nBytes));
  }

  /**
   * Convert a byte array to a lower-case hexadecimal string, two characters per byte.
   *
   * @param bytes The byte array to convert.
   * @return The equivalent hexadecimal string.
   */
  public static String bytesToHex(byte[] bytes) {
    Preconditions.checkNotNull(bytes);
    Formatter f = new Formatter(new StringBuilder(bytes.length * 2));
    for (byte b : bytes) {
      // %x on a negative byte yields its unsigned value, e.g. (byte) -1 -> "ff".
      f.format("%02x", b);
    }
    return f.toString();
  }

  /**
   * Convert a hexadecimal string (either case) to a byte array.
   *
   * @param hexString The hexadecimal string to convert.
   * @return The equivalent array of bytes.
   * @throws IllegalArgumentException if the string isn't valid hexadecimal: it
   *     has odd length, or contains a character other than ASCII 0-9, a-f, A-F.
   */
  public static byte[] hexToBytes(String hexString) {
    Preconditions.checkNotNull(hexString);
    int len = hexString.length();
    Preconditions.checkArgument(len % 2 == 0,
        "Hexadecimal string must have an even number of characters; length is %s", len);
    byte[] decoded = new byte[len / 2];
    for (int i = 0; i < decoded.length; i += 1) {
      int high = hexDigitValue(hexString.charAt(2 * i));
      int low = hexDigitValue(hexString.charAt(2 * i + 1));
      if (high < 0 || low < 0) {
        throw new IllegalArgumentException("Non-hexadecimal character in string: " + hexString);
      }
      decoded[i] = (byte) ((high << 4) | low);
    }
    return decoded;
  }

  // Returns the value of an ASCII hexadecimal digit, or -1 if c is not one.
  // Character.digit is deliberately not used: it also accepts non-ASCII Unicode
  // digits (e.g. fullwidth forms), which are not valid hexadecimal here.
  private static int hexDigitValue(char c) {
    if (c >= '0' && c <= '9') {
      return c - '0';
    }
    if (c >= 'a' && c <= 'f') {
      return c - 'a' + 10;
    }
    if (c >= 'A' && c <= 'F') {
      return c - 'A' + 10;
    }
    return -1;
  }

  /**
   * Is a given remote "before" time valid?  In other words, is it possible that
   * the remote "before" time is less than or equal to the remote "now" time?
   *
   * @param before A before time from a remote host.
   * @param now The current time on this host.
   * @return True if, allowing for clock skew, the before time might already
   *     have been reached on the remote host.
   */
  public static boolean isRemoteBeforeTimeValid(long before, long now) {
    return before - CLOCK_SKEW_TIME <= now;
  }

  /**
   * Is a given remote "on or after" time valid?  In other words, is it possible
   * that the remote "on or after" time is greater than the remote "now" time?
   *
   * @param onOrAfter An on-or-after time from a remote host.
   * @param now The current time on this host.
   * @return True if, allowing for clock skew, the on-or-after time might not yet
   *     have been reached on the remote host.
   */
  public static boolean isRemoteOnOrAfterTimeValid(long onOrAfter, long now) {
    return onOrAfter + CLOCK_SKEW_TIME > now;
  }

  /**
   * @return The clock-skew allowance used by the isRemote*TimeValid methods.
   */
  @VisibleForTesting
  public static long getClockSkewTime() {
    return CLOCK_SKEW_TIME;
  }

  /**
   * Compare two URLs for equality.  Preferable to using the {@link URL#equals}
   * because the latter calls out to DNS and can block.
   *
   * <p>Protocol and host are compared case-insensitively; file (path plus
   * query) and ref are compared exactly.  As in {@link URL#equals}, an omitted
   * port is treated as the protocol's default port, so {@code http://h/} equals
   * {@code http://h:80/}.
   *
   * @param url1 A URL to compare.
   * @param url2 Another URL to compare.
   * @return True if the two URLs are the same (two nulls are equal).
   */
  public static boolean areUrlsEqual(URL url1, URL url2) {
    if (url1 == null || url2 == null) {
      return url1 == null && url2 == null;
    }
    return areStringsEqualIgnoreCase(url1.getProtocol(), url2.getProtocol())
        && areStringsEqualIgnoreCase(url1.getHost(), url2.getHost())
        && effectivePort(url1) == effectivePort(url2)
        && areStringsEqual(url1.getFile(), url2.getFile())
        && areStringsEqual(url1.getRef(), url2.getRef());
  }

  // Returns the URL's explicit port, or the protocol's default port when none is given.
  private static int effectivePort(URL url) {
    int port = url.getPort();
    return (port != -1) ? port : url.getDefaultPort();
  }

  // Null-safe, case-sensitive string equality.
  private static boolean areStringsEqual(String s1, String s2) {
    return s1 == s2 || ((s1 == null) ? s2 == null : s1.equals(s2));
  }

  // Null-safe, case-insensitive string equality.
  private static boolean areStringsEqualIgnoreCase(String s1, String s2) {
    return s1 == s2 || ((s1 == null) ? s2 == null : s1.equalsIgnoreCase(s2));
  }

  /**
   * @return The value of ENT_CONFIG_NAME from the GSA configuration, as exposed
   *   through the {@code gsa.entityid} system property.  If that property is not
   *   set (e.g. not running on a GSA, as in testing), return "testing".
   */
  public static String getGsaEntConfigName() {
    String entConfigName = System.getProperty("gsa.entityid");
    return (entConfigName != null) ? entConfigName : "testing";
  }

  /**
   * @return A URI builder with default scheme ("http") and host ("google.com").
   */
  public static UriBuilder uriBuilder() {
    return new UriBuilder("http", "google.com");
  }

  /**
   * @param scheme The URI Scheme to use.
   * @param host The URI host to use.
   * @return A URI builder with the given scheme and host.
   */
  public static UriBuilder uriBuilder(String scheme, String host) {
    return new UriBuilder(scheme, host);
  }

  /**
   * A class to build URIs by incrementally specifying their path segments.
   * Instances are mutable and not synchronized; confine each one to a single thread.
   */
  public static final class UriBuilder {
    private final String scheme;
    private final String host;
    private final StringBuilder pathBuilder;

    private UriBuilder(String scheme, String host) {
      this.scheme = scheme;
      this.host = host;
      pathBuilder = new StringBuilder();
    }

    /**
     * Add a segment to the path being accumulated.
     *
     * @param segment The segment to add.
     * @return The builder, for convenience.
     * @throws IllegalArgumentException if the segment is null or contains a "/".
     */
    public UriBuilder addSegment(String segment) {
      Preconditions.checkArgument(segment != null && !segment.contains("/"),
          "Path segments may not contain the / character: %s", segment);
      pathBuilder.append("/").append(segment);
      return this;
    }

    /**
     * Add a hex-encoded random segment to the path being accumulated.
     *
     * @param nBytes The number of random bytes in the segment.
     * @return The builder, for convenience.
     */
    public UriBuilder addRandomSegment(int nBytes) {
      return addSegment(generateRandomNonceHex(nBytes));
    }

    /**
     * Builds the URI.  The multi-argument {@link URI} constructor used here
     * quotes any characters that are illegal in the path.
     *
     * @return The URI composed of the accumulated parts.
     * @throws IllegalArgumentException if there's a syntax problem with one of the parts.
     */
    public URI build() {
      try {
        return new URI(scheme, host, pathBuilder.toString(), null);
      } catch (URISyntaxException e) {
        throw new IllegalArgumentException(e);
      }
    }
  }

  // Returns a builder for http://google.com/enterprise/gsa/<ent-config-name>.
  private static UriBuilder gsaUriBuilder() {
    return uriBuilder()
        .addSegment("enterprise")
        .addSegment("gsa")
        .addSegment(getGsaEntConfigName());
  }

  /**
   * @return A URI builder for the security manager's base URI:
   *     http://google.com/enterprise/gsa/&lt;ent-config-name&gt;/security-manager.
   */
  public static UriBuilder smUriBuilder() {
    return gsaUriBuilder()
        .addSegment("security-manager");
  }

  // TODO(cph): make this configurable (preferably in sec mgr config).
  private static final int THREAD_POOL_SIZE = 20;

  // Primary pool: runs individual tasks.  NOTE: Executors.newFixedThreadPool
  // creates non-daemon threads, and these pools are never shut down.
  private static final ExecutorService THREAD_POOL
      = Executors.newFixedThreadPool(THREAD_POOL_SIZE);

  // Batches of work that are themselves parallelizable use a 2nd pool to
  // parallelize batches while the 1st pool is used for work within the batches.
  private static final ExecutorService THREAD_POOL_2
      = Executors.newFixedThreadPool(THREAD_POOL_SIZE);

  // Timeout difference between THREAD_POOL and THREAD_POOL_2.
  // Otherwise there is a race between their layered use.
  // So when batches are submitted they use THREAD_POOL_2 to manage
  // batches, and THREAD_POOL is given less time to carry out tasks.
  private static final long THREAD_POOL_DELAY_MILLIS = 20;

  @VisibleForTesting
  static int getPrimaryThreadPoolSize() {
    return THREAD_POOL_SIZE;
  }

  /**
   * Runs a bunch of tasks in parallel using the default/primary thread pool.
   *
   * @param callables The tasks to be run.
   * @param timeoutMillis The maximum amount of time allowed for processing all
   *     the tasks.  Tasks still running when it expires are cancelled.
   * @param sessionId A session ID to use for logging.
   * @return A list of the non-null computed values, in no particular order.
   *     The number of values is normally the same as the number of tasks, but
   *     if the timeoutMillis is reached, if one or more of the tasks generates an
   *     exception, or if a task returns null, there will be fewer values than tasks.
   */
  @CheckReturnValue
  @Nonnull
  public static <T> List<T> runInParallel(
      @Nonnull Iterable<Callable<T>> callables,
      @Nonnegative long timeoutMillis,
      @Nonnull String sessionId) {
    return runInParallel(THREAD_POOL, callables, calcEndTimeMillis(timeoutMillis), sessionId);
  }

  // Returns the absolute deadline timeoutMillis from now.  Saturates at
  // Long.MAX_VALUE instead of wrapping, so a huge timeout means "effectively never".
  private static long calcEndTimeMillis(long timeoutMillis) {
    long nowMillis = DateTimeUtils.currentTimeMillis();
    long endTimeMillis = nowMillis + timeoutMillis;
    return (timeoutMillis > 0 && endTimeMillis < nowMillis) ? Long.MAX_VALUE : endTimeMillis;
  }

  // Returns the milliseconds left until endTimeMillis; negative once it has passed.
  private static long calcRemainingMillis(long endTimeMillis) {
    return endTimeMillis - DateTimeUtils.currentTimeMillis();
  }

  /**
   * Runs batches of tasks in parallel.  Each batch is run as one task on a
   * secondary thread pool.  The tasks within a batch are run in parallel on the
   * primary thread pool, through a bounded executor shared by all batches with
   * the same key (presumably capping that key at {@code maxThreadsPerBatch}
   * concurrent tasks -- see BoundedExecutorService).  Tasks within a batch get
   * THREAD_POOL_DELAY_MILLIS less time than the batch itself.
   *
   * @param keyedBatches The batches to run.
   * @param timeoutMillis The maximum amount of time allowed for processing all
   *     the batches.
   * @param sessionId A session ID to use for logging.
   * @param maxThreadsPerBatch The bound used when a key's executor is first
   *     created; later calls with the same key reuse that executor unchanged.
   * @return An immutable list of the non-null values computed by all the tasks
   *     that finished in time without an exception, in no particular order.
   */
  @CheckReturnValue
  @Nonnull
  public static <T> List<T> runBatchesInParallel(
      @Nonnull Iterable<KeyedBatchOfCallables<T>> keyedBatches, @Nonnegative long timeoutMillis,
      @Nonnull String sessionId, int maxThreadsPerBatch) {
    long endTimeMillis = calcEndTimeMillis(timeoutMillis);

    // Convert each batch of callables (a list of callables) into a single
    // callable that has the batch of callables parallelized inside of it.
    List<Callable<List<T>>> batchCallables = Lists.newArrayList();
    for (KeyedBatchOfCallables<T> keyedBatch : keyedBatches) {
      batchCallables.add(
          keyedBatch.toSingleParallelizedCallable(endTimeMillis, sessionId, maxThreadsPerBatch));
    }

    List<List<T>> answerLists = runInParallel(THREAD_POOL_2, batchCallables,
        endTimeMillis, sessionId);
    ImmutableList.Builder<T> builder = ImmutableList.builder();
    for (List<T> answerList : answerLists) {
      builder.addAll(answerList);
    }
    return builder.build();
  }

  // Runs the callables on the given executor until endTimeMillis, then returns
  // the non-null results of those that completed normally.  Exceptions thrown by
  // tasks are logged and skipped; tasks unfinished at the deadline are cancelled
  // by invokeAll and skipped.
  @Nonnull
  private static <T> List<T> runInParallel(
      @Nonnull ExecutorService threadPool,
      @Nonnull Iterable<Callable<T>> callables,
      @Nonnegative long endTimeMillis,
      @Nonnull String sessionId) {
    Preconditions.checkNotNull(threadPool);
    Preconditions.checkNotNull(callables);
    Preconditions.checkArgument(endTimeMillis >= 0);
    Preconditions.checkNotNull(sessionId);

    List<T> results = Lists.newArrayList();
    try {
      List<Future<T>> futures = threadPool.invokeAll(Lists.newArrayList(callables),
          calcRemainingMillis(endTimeMillis), TimeUnit.MILLISECONDS);
      for (Future<T> f : futures) {
        try {
          if (f.isDone() && !f.isCancelled()) {
            T singleResult = f.get();
            if (null != singleResult) {
              results.add(singleResult);
            }
          }
        } catch (ExecutionException e) {
          LOGGER.log(Level.WARNING,
              SecurityManagerUtil.sessionLogMessage(sessionId, "Exception in worker thread: "),
              e);
        }
      }

    } catch (InterruptedException e) {
      // Restore the interrupt status, then return whatever was collected so far.
      Thread.currentThread().interrupt();
    }
    return results;
  }

  // Contains bounded executors per key.  Guarded by itself.  Entries are never
  // removed, so the set of keys used should be bounded.
  private static final Map<String, ExecutorService> boundedServicers
      = new HashMap<String, ExecutorService>();

  /** Returns a bounded executor service for a given key, unless
    one doesn't exist already, in which case it's constructed with
    maxThreadsPerBatch parameter.  Note that if a bounded executor
    service by a given name already exists then it's provided
    without checking that it uses the same number of maxThreadsPerBatch. */
  private static ExecutorService getServiceForKey(String key, int maxThreadsPerBatch) {
    synchronized (boundedServicers) {
      ExecutorService service = boundedServicers.get(key);
      if (service == null) {
        service = new BoundedExecutorService(maxThreadsPerBatch, /*fair*/ false, THREAD_POOL);
        boundedServicers.put(key, service);
      }
      return service;
    }
  }

  /** A keyed list of Callables that can be converted into a single
    Callable that parallelizes the original list of work. */
  public static class KeyedBatchOfCallables<T> {
    private final String key;
    private final List<Callable<T>> work;

    /**
     * @param key Identifies the bounded executor that this batch's work runs on.
     * @param work The tasks in this batch; copied defensively.
     */
    public KeyedBatchOfCallables(String key, List<Callable<T>> work) {
      this.key = key;
      this.work = ImmutableList.copyOf(work);
    }

    /** Returns callable that when invoked performs all the work
      callables provided in constructor in parallel.  A BoundedExecutorService
      is used to limit the resources that the parallelization takes.  The work
      is given THREAD_POOL_DELAY_MILLIS less than endTimeMillis, so that it
      finishes before the enclosing batch's deadline. */
    private Callable<List<T>> toSingleParallelizedCallable(final long endTimeMillis,
        final String sessionId, int maxThreadsPerBatch) {
      final ExecutorService limitedExecutor = getServiceForKey(key, maxThreadsPerBatch);
      return new Callable<List<T>>() {
        @Override
        public List<T> call() {
          return runInParallel(limitedExecutor, work,
              endTimeMillis - THREAD_POOL_DELAY_MILLIS, sessionId);
        }
      };
    }
  }
}
